# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "matplotlib",
#     "numpy",
#     "polars",
#     "pyarrow",
#     "seaborn",
#     "statsmodels",
# ]
# ///
"""
Reproduce figures 2, C1, C2, D1, D2, D3 of paper:
    "Information control on YouTube during Russia’s invasion of Ukraine"
Figure 1 is created in R_reproduction_figure_1.R.

Usage:
    uv run construct_figures.py
"""

from pathlib import Path
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import statsmodels.api as sm

# Date marked as "Russia's invasion of Ukraine" in the per-channel facet plot.
WAR_DATE = datetime(2022, 2, 24)
# Date of the YouTube ban; it is day 0 on the "Days since Youtube Ban" axes,
# and WAR_DATE is positioned relative to it.
YTBAN_DATE = datetime(2022, 3, 4)


def plot_comments(gpb: pl.DataFrame, ax) -> None:
    """Draw total comments per day for each channel type on ``ax``.

    Args:
        gpb: Frame with the columns ``type`` (values ``pro_kremlin``,
            ``anti_kremlin`` or ``entertainment``), ``publish_date``
            (x axis) and ``comment_id`` (y axis, divided by 10,000 before
            plotting).
        ax: Matplotlib axes to draw on; it is modified in place.
    """
    # Named up front so the relabelling and the palette share one definition
    # instead of relying on argument evaluation order.
    pro_kremlin_label = "Pro-Kremlin (non-blocked)"
    type_labels = {
        "pro_kremlin": pro_kremlin_label,
        "anti_kremlin": "Anti-Kremlin",
        "entertainment": "Entertainment",
    }
    palette = {
        pro_kremlin_label: "#d62728",  # red
        "Anti-Kremlin": "#1f77b4",  # blue
        "Entertainment": "#7f7f7f",  # gray
    }

    # Only seaborn "context" keys are listed; the former "title.size" entry
    # is not one of them and had no effect.
    sns.set_context(
        "paper",
        rc={
            "font.size": 13,
            "axes.titlesize": 13,
            "axes.labelsize": 13,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
            "legend.title_fontsize": 12,
        },
    )

    plot_data = gpb.with_columns(
        # replace_strict is given no default, so every ``type`` value is
        # expected to be one of the three keys above.
        pl.col("type").replace_strict(type_labels),
        pl.col("comment_id").truediv(10_000),
    ).to_pandas()
    sns.lineplot(
        data=plot_data,
        x="publish_date",
        y="comment_id",
        hue="type",
        marker="o",
        ax=ax,
        palette=palette,
    )
    ax.set(
        xlabel="Days since Youtube Ban",
        ylabel="Total comments (in 10k)",
    )
    # Vertical marker at day 0; labelled so it shows up in the legend below.
    ax.axvline(
        x=0,
        color="black",
        linestyle="--",
        label="Youtube Ban",
    )
    ax.set_xticks(ticks=np.arange(-40, 41, 10))
    leg = ax.legend(title="Activity on channels", frameon=False)
    leg.set_alignment("left")
    ax.spines[["top", "right"]].set_visible(False)


def plot_comments_curve(gp_c, ax):
    """Plot comments per day for each period, each with an OLS trend line.

    ``gp_c`` is converted with ``to_pandas()`` and must provide the columns
    ``period`` (``"Before ban"`` or ``"After ban"``), ``publish_date`` and
    ``comment_id``. For every period the points are drawn in the period's
    colour and a semi-transparent black line shows the ordinary-least-squares
    fit of the scaled comment count on ``publish_date`` (with intercept).
    ``ax`` is modified in place.
    """
    period_colors = {
        "Before ban": "#1f77b4",  # blue
        "After ban": "#d62728",  # red
    }
    frame = gp_c.to_pandas()

    for period, rows in frame.groupby("period"):
        ordered = rows.sort_values("publish_date")
        days = ordered["publish_date"]
        comments_10k = ordered["comment_id"] / 10_000

        # Observed values for this period.
        ax.plot(
            days,
            comments_10k,
            color=period_colors[period],
            label=period,
            linestyle="-",
            marker="o",
        )

        # Linear trend fitted separately within the period.
        design = sm.add_constant(days)
        fitted = sm.OLS(comments_10k, design).fit()
        ax.plot(
            days,
            fitted.predict(design),
            alpha=0.5,
            color="black",
            linestyle="-",
        )

    # Marker for day 0; drawn before the legend is built so it is listed.
    ax.axvline(
        x=0,
        color="black",
        linestyle="--",
        label="Youtube Ban",
    )
    ax.set(
        xlabel="Days since Youtube Ban",
        ylabel="Total comments (in 10k)",
    )
    ax.set_xticks(ticks=np.arange(-40, 41, 10))
    ax.spines[["top", "right"]].set_visible(False)
    ax.legend(frameon=False)


def plot_activity(
    df: pl.DataFrame,
    ax,
    ycol: str = "comment_id",
    pro_kremlin_c: str = "Pro-Kremlin",
    div_factor: int = 1_000,
) -> None:
    """Draw one line per channel type for the column ``ycol`` on ``ax``.

    Axis labels are not set here; the caller is expected to add them.

    Args:
        df: Frame with the columns ``type`` (values ``pro_kremlin``,
            ``anti_kremlin`` or ``entertainment``), ``publish_date``
            (x axis) and ``ycol``.
        ax: Matplotlib axes to draw on; it is modified in place.
        ycol: Name of the column plotted on the y axis.
        pro_kremlin_c: Legend label used for the ``pro_kremlin`` type.
        div_factor: ``ycol`` is divided by this value before plotting.
    """
    type_labels = {
        "pro_kremlin": pro_kremlin_c,
        "anti_kremlin": "Anti-Kremlin",
        "entertainment": "Entertainment",
    }
    palette = {
        pro_kremlin_c: "#d62728",  # red
        "Anti-Kremlin": "#1f77b4",  # blue
        "Entertainment": "#7f7f7f",  # gray
    }

    # Only seaborn "context" keys are listed; the former "title.size" entry
    # is not one of them and had no effect.
    sns.set_context(
        "paper",
        rc={
            "font.size": 13,
            "axes.titlesize": 13,
            "axes.labelsize": 13,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
            "legend.title_fontsize": 12,
        },
    )

    plot_data = df.with_columns(
        # replace_strict is given no default, so every ``type`` value is
        # expected to be one of the three keys above.
        pl.col("type").replace_strict(type_labels),
        pl.col(ycol).truediv(div_factor),
    ).to_pandas()
    sns.lineplot(
        data=plot_data,
        x="publish_date",
        y=ycol,
        hue="type",
        marker="o",
        ax=ax,
        palette=palette,
    )
    # Vertical marker at day 0; labelled so it shows up in the legend below.
    ax.axvline(
        x=0,
        color="black",
        linestyle="--",
        label="Youtube Ban",
    )
    ax.set_xticks(ticks=np.arange(-40, 41, 10))
    ax.legend(frameon=False, title="Channel type")
    ax.spines[["top", "right"]].set_visible(False)


def plot_facet(gp_plot: pl.DataFrame):
    """Draw one comments-per-day line panel per channel.

    Every panel gets two vertical reference lines: the YouTube ban at day 0
    and the invasion date, placed at its day offset from the ban.

    Args:
        gp_plot: Frame with the columns ``channel`` (one panel per value),
            ``publish_date`` (x axis) and ``comment_id`` (y axis, divided
            by 1,000 before plotting).

    Returns:
        The seaborn ``FacetGrid`` holding the panels (three per row, each
        with its own y scale).
    """
    # Only seaborn "context" keys are listed; the former "title.size" entry
    # is not one of them and had no effect.
    sns.set_context(
        "paper",
        rc={
            "font.size": 13,
            "axes.titlesize": 13,
            "axes.labelsize": 13,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
            "legend.title_fontsize": 12,
        },
    )
    g = sns.FacetGrid(
        data=gp_plot.with_columns(
            pl.col("comment_id").truediv(1_000),
        ).to_pandas(),
        col="channel",
        col_wrap=3,
        sharey=False,
        height=4.0,  # increase subplot height (default is 3)
        aspect=1.4,  # width/height ratio (default is 1)
    )
    # Position of the invasion on the "days since ban" axis (negative: it
    # precedes the ban).
    relative_war = (WAR_DATE - YTBAN_DATE).days
    # Keep the reference lines ahead of ``g.map``: the legend entries for
    # them appear to be collected while mapping.
    # NOTE(review): confirm against seaborn's FacetGrid legend handling
    # before reordering these statements.
    for ax in g.axes.flat:
        ax.axvline(x=0, color="red", linestyle="--", label="Youtube Ban")
        ax.axvline(
            x=relative_war,
            color="blue",
            linestyle=":",
            label="Russia's invasion of Ukraine",
        )
    g.map(sns.lineplot, "publish_date", "comment_id", marker="o", color="black")
    g.add_legend()
    g.set_axis_labels(
        "Days since Youtube Ban",
        "Total comments (in 1k)",
    )
    g.set(xticks=np.arange(-40, 41, 10))
    g.tight_layout()

    return g


def plot_barplot(post_block: pl.DataFrame):
    """Draw a 3x3 grid of per-channel bar charts.

    Args:
        post_block: Frame with the columns ``channel``, ``publish_date``
            (x axis) and ``N`` (bar height, labelled "Number of comments").

    Returns:
        The matplotlib figure holding the nine panels.
    """
    # One panel per channel, in this order.
    channels = [
        "ren-tv",
        "ria-news",
        "rossiya-1",
        "rt-на-русском",
        "sputnik",
        "звезда",
        "россия-24",
        "телеканал-360",
        "царьград-тв",
    ]
    f, axes = plt.subplots(3, 3, figsize=(45, 35))
    # ``axes.flat`` walks the grid row by row, which replaces the former
    # hand-written row/column index lists. ``strict=True`` fails loudly if the
    # channel list and the grid size ever get out of sync.
    for panel, chan in zip(axes.flat, channels, strict=True):
        sns.barplot(
            x="publish_date",
            y="N",
            color="green",
            # Converted to pandas like the other plotting helpers in this file.
            data=post_block.filter(pl.col("channel").eq(chan)).to_pandas(),
            ax=panel,
        )
        panel.set_xlabel("")
        panel.set_ylabel("Number of comments")
        panel.set_title(chan, size=40)
        panel.tick_params(axis="x", labelrotation=90, labelsize=23)
    f.tight_layout()
    return f


def main() -> None:
    """Load the replication CSVs and write figures 2, C1, C2, D1, D2 and D3.

    Input is read from ``data-replication/`` and the PDFs are written to
    ``figures/`` (created if missing); both are resolved against the current
    working directory.
    """
    fp_data = Path.cwd().joinpath("data-replication")
    fp_figs = Path.cwd().joinpath("figures")
    fp_figs.mkdir(exist_ok=True)

    # Apply the font sizes before the first Axes is created. Matplotlib fixes
    # axis-label font sizes when an Axes is built, and Figure 2's Axes is
    # created below before any plotting helper has set the context, so it
    # would otherwise not match figures D1-D3.
    # NOTE(review): this changes Figure 2's axis text size compared with the
    # previous output - confirm that is the intended look.
    sns.set_context(
        "paper",
        rc={
            "font.size": 13,
            "axes.titlesize": 13,
            "axes.labelsize": 13,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
            "legend.title_fontsize": 12,
        },
    )

    # Load data (all files up front, so a missing input fails before any
    # figure is written)
    fig2_data = pl.read_csv(fp_data / "fig2_data.csv")
    figc1_data = pl.read_csv(fp_data / "figc1_data.csv")
    figc2_data = pl.read_csv(fp_data / "figc2_data.csv")
    figd1_data = pl.read_csv(fp_data / "figd1_data.csv")
    figd2_data = pl.read_csv(fp_data / "figd2_data.csv")
    figd3_data = pl.read_csv(fp_data / "figd3_data.csv")

    # Each figure is closed once saved so pyplot does not keep it alive.

    # Figure 2
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_comments(fig2_data, ax)
    fig.savefig(fp_figs / "figure-2.pdf", dpi=300)
    plt.close(fig)

    # Figure C1
    f = plot_barplot(figc1_data)
    f.savefig(fp_figs / "figure-c1.pdf")
    plt.close(f)

    # Figure C2
    g = plot_facet(figc2_data)
    g.savefig(fp_figs / "figure-c2.pdf")
    plt.close(g.figure)

    # Figure D1
    fig, ax = plt.subplots(figsize=(10, 6))
    plot_comments_curve(figd1_data, ax)
    fig.savefig(fp_figs / "figure-d1.pdf", dpi=300)
    plt.close(fig)

    # Figures D2 and D3: same plot settings, different input data
    activity_figures = (
        ("figure-d2.pdf", figd2_data),
        ("figure-d3.pdf", figd3_data),
    )
    for filename, data in activity_figures:
        fig, ax = plt.subplots(figsize=(10, 6))
        plot_activity(
            data,
            ax,
            ycol="video_id",
            pro_kremlin_c="Pro-Kremlin (non-blocked)",
        )
        ax.set(
            xlabel="Days since Youtube Ban",
            ylabel="Total videos (in 1k)",
        )
        fig.savefig(fp_figs / filename, dpi=300)
        plt.close(fig)


if __name__ == "__main__":
    main()
